package com.chatplus.application.aiprocessor.platform.image.sd;

import com.chatplus.application.aiprocessor.handler.SdWebSocketHandler;
import com.chatplus.application.aiprocessor.platform.image.ImgAiProcessorService;
import com.chatplus.application.aiprocessor.provider.ImgAiProcessorServiceProvider;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.domain.dto.ApiKeyDto;
import com.chatplus.application.domain.dto.TextPromptsDto;
import com.chatplus.application.domain.dto.extend.ImgResultDto;
import com.chatplus.application.domain.entity.draw.SdJobEntity;
import com.chatplus.application.domain.notification.SdStabilityJobNotification;
import com.chatplus.application.domain.request.SdStabilityImageRequest;
import com.chatplus.application.enumeration.AiPlatformEnum;
import com.chatplus.application.service.account.UserProductLogService;
import com.chatplus.application.service.draw.SdJobService;
import com.chatplus.application.util.ConfigUtil;
import com.chatplus.application.web.notification.NotificationPublisher;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

/**
 * Image processor for the Stable Diffusion (Stability) platform.
 * <p>
 * {@link #process(Object)} does not generate the image inline: it prepares the request's
 * weighted text prompts, publishes an {@link SdStabilityJobNotification} for the job,
 * pushes a task-update message to the user and deducts the channel's image power cost.
 **/
@Service(value = ImgAiProcessorServiceProvider.SERVICE_NAME_PRE + "SDSTABILITY")
public class SdStabilityImageProcessor extends ImgAiProcessorService {
    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(SdStabilityImageProcessor.class);

    /** Weight attached to the (positive) prompt text. */
    private static final int POSITIVE_PROMPT_WEIGHT = 1;
    /** Weight attached to the negative prompt text; the negative sign marks it as a negative prompt. */
    private static final int NEGATIVE_PROMPT_WEIGHT = -1;

    private final NotificationPublisher notificationPublisher;
    private final SdJobService sdJobService;
    private final SdWebSocketHandler sdWebSocketHandler;
    private final UserProductLogService userProductLogService;

    public SdStabilityImageProcessor(NotificationPublisher notificationPublisher,
                                     SdJobService sdJobService,
                                     SdWebSocketHandler sdWebSocketHandler,
                                     UserProductLogService userProductLogService) {
        this.notificationPublisher = notificationPublisher;
        this.sdJobService = sdJobService;
        this.sdWebSocketHandler = sdWebSocketHandler;
        this.userProductLogService = userProductLogService;
    }

    /**
     * Submits a Stability image job.
     * <p>
     * Fills the request's text prompts from its prompt and optional negative prompt. It then
     * publishes an {@link SdStabilityJobNotification} carrying the first configured API key,
     * the job id and the request. Finally it sends a task-update message (progress {@code 0})
     * to the user and reduces the user's power by the configured image cost for this channel.
     *
     * @param prompt expected to be an {@link SdStabilityImageRequest}; {@code null} is a no-op
     * @return always {@code null}; the work is handed off via the published notification
     * @throws ClassCastException if {@code prompt} is not an {@link SdStabilityImageRequest}
     */
    public ImgResultDto process(Object prompt) {
        SdStabilityImageRequest request = (SdStabilityImageRequest) prompt;
        if (request == null) {
            return null;
        }
        request.setTextPrompts(buildTextPrompts(request.getPrompt(), request.getNegativePrompt()));

        // Always uses the first configured key. NOTE(review): presumably throws if no key is
        // configured; that happens before anything is published or charged.
        ApiKeyDto apiKeyDto = openApiKey.getFirst();
        SdStabilityJobNotification sdJobNotification = new SdStabilityJobNotification();
        sdJobNotification.setApiKeyDto(apiKeyDto);
        sdJobNotification.setSdJobId(request.getSdJobId());
        sdJobNotification.setRequest(request);
        notificationPublisher.publish(sdJobNotification);

        sdWebSocketHandler.sendTaskUpdatedMessage(request.getUserId(), 0);
        userProductLogService.reducePower(request.getUserId(),
                ConfigUtil.getImagePower(request.getUserId(), getChannel(), null), getChannel(), null);
        return null;
    }

    /**
     * Builds the weighted text-prompt list for a Stability request.
     *
     * @param promptText         positive prompt text, always added
     * @param negativePromptText negative prompt text, added only when non-empty
     * @return a mutable list: the positive prompt first, then the negative prompt if present
     */
    private static List<TextPromptsDto> buildTextPrompts(String promptText, String negativePromptText) {
        List<TextPromptsDto> textPrompts = new ArrayList<>(2);
        textPrompts.add(newTextPrompt(promptText, POSITIVE_PROMPT_WEIGHT));
        if (StringUtils.isNotEmpty(negativePromptText)) {
            textPrompts.add(newTextPrompt(negativePromptText, NEGATIVE_PROMPT_WEIGHT));
        }
        return textPrompts;
    }

    /** Creates a single {@link TextPromptsDto} with the given text and weight. */
    private static TextPromptsDto newTextPrompt(String text, int weight) {
        TextPromptsDto textPromptsDto = new TextPromptsDto();
        textPromptsDto.setText(text);
        textPromptsDto.setWeight(weight);
        return textPromptsDto;
    }

    @Override
    public long getRunningJobCount(Long userId) {
        return sdJobService.getRunningJobCount(userId);
    }

    @Override
    public AiPlatformEnum getChannel() {
        return AiPlatformEnum.SD_STABILITY;
    }

    /**
     * Pushes the current progress of a job to its owner over the websocket.
     *
     * @param jobId id of the SD job to report on; if no job is found, nothing is sent
     */
    @Override
    public void notifyUpdateTask(Long jobId) {
        SdJobEntity sdJobEntity = sdJobService.getById(jobId);
        if (sdJobEntity == null) {
            // No job with this id (e.g. possibly deleted): nothing to notify, and dereferencing
            // the entity would otherwise fail with a NullPointerException.
            return;
        }
        sdWebSocketHandler.sendTaskUpdatedMessage(sdJobEntity.getUserId(), sdJobEntity.getProgress());
    }
}
